{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "220e6bbb-c0c0-4f5d-9a63-bf2de0d39c67",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-10-23T02:21:14.910438Z",
     "start_time": "2024-10-23T02:21:13.357170700Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the local efficient-kan checkout importable. A raw string keeps the\n",
    "# Windows backslashes literal; the plain literal relied on unrecognized escape\n",
    "# sequences (\\p, \\K, \\e) that Python flags as invalid escapes.\n",
    "# NOTE(review): machine-specific absolute path -- adjust for your environment.\n",
    "sys.path.append(r'H:\\python\\KAN\\KAN_VAE_MNIST-main\\KAN_VAE_MNIST-main\\efficient-kan-master')\n",
    "from src.efficient_kan import KAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f08a8838",
   "metadata": {
    "papermill": {
     "duration": 5.790567,
     "end_time": "2024-06-27T02:45:29.465864",
     "exception": false,
     "start_time": "2024-06-27T02:45:23.675297",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:21:20.642477500Z",
     "start_time": "2024-10-23T02:21:18.958467900Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\admin\\AppData\\Roaming\\Python\\Python310\\site-packages\\pandas\\core\\arrays\\masked.py:60: UserWarning: Pandas requires version '1.3.6' or newer of 'bottleneck' (version '1.3.5' currently installed).\n",
      "  from pandas.core import (\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np \n",
    "import pandas as pd \n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchvision\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision import transforms as tfs\n",
    "from torchvision.utils import save_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d7452a0b",
   "metadata": {
    "papermill": {
     "duration": 1.908979,
     "end_time": "2024-06-27T02:45:31.378366",
     "exception": false,
     "start_time": "2024-06-27T02:45:29.469387",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:36:22.197592400Z",
     "start_time": "2024-10-23T02:36:22.150269100Z"
    }
   },
   "outputs": [],
   "source": [
    "# Preprocessing: convert images to tensors, then apply Normalize(mean=0.5, std=0.5),\n",
    "# i.e. (value - 0.5) / 0.5 on the single grayscale channel.\n",
    "im_tfs = tfs.Compose([\n",
    "    tfs.ToTensor(),\n",
    "    tfs.Normalize((0.5, ), (0.5,))\n",
    "#     tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # normalization for 3-channel images\n",
    "])\n",
    "\n",
    "# download=True is a no-op when the files already exist under ./mnist, and it lets\n",
    "# the notebook run on a fresh checkout instead of failing because the data is missing.\n",
    "train_set = torchvision.datasets.MNIST(\n",
    "    root=\"./mnist\", train=True, download=True, transform=im_tfs\n",
    ")\n",
    "val_set = torchvision.datasets.MNIST(\n",
    "    root=\"./mnist\", train=False, download=True, transform=im_tfs\n",
    ")\n",
    "\n",
    "# Kaggle variant: train_set = MNIST('/kaggle/working/mnist', transform=im_tfs)\n",
    "# Training batches are reshuffled every epoch; the held-out split keeps its order.\n",
    "train_data = DataLoader(train_set, batch_size=128, shuffle=True)\n",
    "test_data = DataLoader(val_set, batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3d42662b",
   "metadata": {
    "papermill": {
     "duration": 0.019369,
     "end_time": "2024-06-27T02:45:31.403148",
     "exception": false,
     "start_time": "2024-06-27T02:45:31.383779",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:36:25.926918700Z",
     "start_time": "2024-10-23T02:36:25.916915Z"
    }
   },
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    \"\"\"Variational autoencoder over 784-dimensional inputs, built from KAN layers.\n",
    "\n",
    "    The encoder maps an input to the mean and log-variance of a 20-dimensional\n",
    "    latent Gaussian; the decoder maps a latent code back to 784 values in\n",
    "    [-1, 1] (tanh output).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Encoder: 784 -> 400 hidden units -> (mean, log-variance), 20 dims each.\n",
    "        self.fc1 = KAN([784, 400])\n",
    "        self.fc21 = KAN([400, 20])  # latent mean\n",
    "        self.fc22 = KAN([400, 20])  # latent log-variance\n",
    "        # Decoder: 20 -> 400 hidden units -> 784 outputs.\n",
    "        self.fc3 = KAN([20, 400])\n",
    "        self.fc4 = KAN([400, 784])\n",
    "        # Plain-MLP counterpart of the layers above, kept for reference:\n",
    "        # self.fc1 = nn.Linear(784, 400)\n",
    "        # self.fc21 = nn.Linear(400, 20) # mean\n",
    "        # self.fc22 = nn.Linear(400, 20) # var\n",
    "        # self.fc3 = nn.Linear(20, 400)\n",
    "        # self.fc4 = nn.Linear(400, 784)\n",
    "\n",
    "    def encode(self, x):\n",
    "        \"\"\"Return ``(mu, logvar)`` of the latent distribution for input ``x``.\"\"\"\n",
    "        h1 = F.relu(self.fc1(x))\n",
    "        return self.fc21(h1), self.fc22(h1)\n",
    "\n",
    "    def reparametrize(self, mu, logvar):\n",
    "        \"\"\"Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.\n",
    "\n",
    "        The noise is drawn with ``torch.randn_like`` so it always has the same\n",
    "        device and dtype as ``std``. Picking the device from\n",
    "        ``torch.cuda.is_available()`` fails whenever the model itself runs on\n",
    "        the CPU of a CUDA-capable machine.\n",
    "        \"\"\"\n",
    "        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2), so sigma = exp(logvar / 2)\n",
    "        eps = torch.randn_like(std)\n",
    "        return eps * std + mu\n",
    "\n",
    "    def decode(self, z):\n",
    "        \"\"\"Map latent codes ``z`` to reconstructions squashed into [-1, 1].\"\"\"\n",
    "        h3 = F.relu(self.fc3(z))\n",
    "        # torch.tanh replaces the deprecated nn.functional.tanh.\n",
    "        return torch.tanh(self.fc4(h3))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Encode ``x``, sample a latent code and decode it.\n",
    "\n",
    "        Returns:\n",
    "            Tuple ``(reconstruction, mu, logvar)``.\n",
    "        \"\"\"\n",
    "        mu, logvar = self.encode(x)  # encode\n",
    "        z = self.reparametrize(mu, logvar)  # sample from N(mu, sigma^2) via reparameterization\n",
    "        return self.decode(z), mu, logvar  # decode; also return mean and log-variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c892e68d",
   "metadata": {
    "papermill": {
     "duration": 0.225375,
     "end_time": "2024-06-27T02:45:31.633817",
     "exception": false,
     "start_time": "2024-06-27T02:45:31.408442",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:36:28.811648100Z",
     "start_time": "2024-10-23T02:36:28.598853100Z"
    }
   },
   "outputs": [],
   "source": [
    "# Instantiate the network and move it to the GPU when one is available.\n",
    "net = VAE()\n",
    "net = net.cuda() if torch.cuda.is_available() else net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "95cbbf69",
   "metadata": {
    "papermill": {
     "duration": 0.561288,
     "end_time": "2024-06-27T02:45:32.200936",
     "exception": false,
     "start_time": "2024-06-27T02:45:31.639648",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:36:29.728974700Z",
     "start_time": "2024-10-23T02:36:29.704981Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0586,  0.0590, -0.0340, -0.0451, -0.0392, -0.0245,  0.0466, -0.0556,\n",
      "          0.0332,  0.0576,  0.0691, -0.0247,  0.0338, -0.0292,  0.0017, -0.0176,\n",
      "         -0.0487,  0.0050,  0.0121,  0.0185]], grad_fn=<ViewBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\admin\\AppData\\Roaming\\Python\\Python310\\site-packages\\torch\\nn\\functional.py:1956: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.\n",
      "  warnings.warn(\"nn.functional.tanh is deprecated. Use torch.tanh instead.\")\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: push the first training image through the network and print\n",
    "# the latent mean it produces.\n",
    "x = train_set[0][0]\n",
    "x = x.view(x.shape[0], -1)  # flatten everything after the first dimension\n",
    "x = Variable(x.cuda() if torch.cuda.is_available() else x)\n",
    "_, mu, var = net(x)  # the third output is the latent log-variance\n",
    "print(mu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "fc8ab604",
   "metadata": {
    "papermill": {
     "duration": 0.022903,
     "end_time": "2024-06-27T02:45:32.229708",
     "exception": false,
     "start_time": "2024-06-27T02:45:32.206805",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:36:31.876278500Z",
     "start_time": "2024-10-23T02:36:31.869289Z"
    }
   },
   "outputs": [],
   "source": [
    "reconstruction_function = nn.MSELoss(reduction='sum')\n",
    "\n",
    "def loss_function(recon_x, x, mu, logvar):\n",
    "    \"\"\"\n",
    "    VAE objective: summed squared reconstruction error plus KL divergence.\n",
    "\n",
    "    recon_x: generated (reconstructed) images\n",
    "    x: original images\n",
    "    mu: latent mean\n",
    "    logvar: latent log variance\n",
    "    Returns MSE + KLD as a sum over all elements (no averaging is done here).\n",
    "    \"\"\"\n",
    "    recon_err = reconstruction_function(recon_x, x)\n",
    "    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n",
    "    kl_terms = -(mu.pow(2) + logvar.exp()) + 1 + logvar\n",
    "    kl_div = -0.5 * kl_terms.sum()\n",
    "    return recon_err + kl_div\n",
    "\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "81553358",
   "metadata": {
    "papermill": {
     "duration": 0.015467,
     "end_time": "2024-06-27T02:45:32.251280",
     "exception": false,
     "start_time": "2024-06-27T02:45:32.235813",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:36:37.580276Z",
     "start_time": "2024-10-23T02:36:37.573468400Z"
    }
   },
   "outputs": [],
   "source": [
    "def to_img(x):\n",
    "    '''\n",
    "    Convert network outputs back into image tensors.\n",
    "\n",
    "    Rescales values from [-1, 1] to [0, 1], clamps anything outside that\n",
    "    range, and reshapes the result to (x.shape[0], 1, 28, 28).\n",
    "    '''\n",
    "    batch = x.shape[0]\n",
    "    rescaled = (0.5 * (x + 1.)).clamp(0, 1)\n",
    "    return rescaled.view(batch, 1, 28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "63b68158",
   "metadata": {
    "papermill": {
     "duration": 1477.573568,
     "end_time": "2024-06-27T03:10:09.831022",
     "exception": false,
     "start_time": "2024-06-27T02:45:32.257454",
     "status": "completed"
    },
    "tags": [],
    "ExecuteTime": {
     "end_time": "2024-10-23T02:48:45.986112800Z",
     "start_time": "2024-10-23T02:36:45.491944800Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, Loss: 92.7421\n",
      "epoch: 2, Loss: 76.6302\n",
      "epoch: 3, Loss: 79.4802\n",
      "epoch: 4, Loss: 64.8842\n",
      "epoch: 5, Loss: 68.1174\n",
      "训练用时720.4846863746643\n"
     ]
    }
   ],
   "source": [
    "start = time.time()\n",
    "for epoch in range(1, 6):\n",
    "    for im, _ in train_data:\n",
    "        # Flatten each image to a vector; move the batch to the GPU if one is available.\n",
    "        im = Variable(im.view(im.shape[0], -1))\n",
    "        if torch.cuda.is_available():\n",
    "            im = im.cuda()\n",
    "        recon_im, mu, logvar = net(im)\n",
    "        # Divide the summed loss by the batch size to get a per-sample average.\n",
    "        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0]\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # After every epoch: report the loss of the last mini-batch and save that\n",
    "    # batch's reconstructions as an image file.\n",
    "    print('epoch: {}, Loss: {:.4f}'.format(epoch, loss.item()))\n",
    "    snapshot = to_img(recon_im.cpu().data)\n",
    "    if not os.path.exists('./vae_img'):\n",
    "        os.mkdir('./vae_img')\n",
    "    save_image(snapshot, './vae_img/kan_image_{}.png'.format(epoch))\n",
    "\n",
    "# Report the total wall-clock training time in seconds.\n",
    "end = time.time()\n",
    "print(f\"训练用时{end-start}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4945a186",
   "metadata": {
    "papermill": {
     "duration": 0.020103,
     "end_time": "2024-06-27T03:10:09.857686",
     "exception": false,
     "start_time": "2024-06-27T03:10:09.837583",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.5430, -0.6629,  0.6543,  0.5544, -0.8801,  0.6379,  1.5880,  0.1135,\n",
      "          0.2006, -0.4661,  0.5582,  1.2552, -1.2923, -0.7195, -0.2301, -0.3854,\n",
      "          0.0798, -0.6304, -0.5992,  1.8669]], grad_fn=<ViewBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Encode the first training image with the current network and print its latent mean.\n",
    "x = train_set[0][0]\n",
    "x = x.view(x.shape[0], -1)  # flatten everything after the first dimension\n",
    "x = Variable(x.cuda() if torch.cuda.is_available() else x)\n",
    "_, mu, _ = net(x)\n",
    "print(mu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "506d7456-4303-4e2a-b574-950fe7e11e24",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x2628d880d00>"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAd/ElEQVR4nO3df2xV9f3H8ddtoRfE9mKp/XHXggUVnPxYZNAxFFEafiwhoiQT9Q9YjEZWzLBzGhYVdT+6scQZtw6XbIGZiDoXgWgyEilS4lYwVJGRuY42dYDQMrr03lKkhfbz/YN4v175+Tnc23dbno/kJL33nnfPm8PpffX0nvu+IeecEwAAfSzDugEAwJWJAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAICJIdYNfFVvb68OHz6s7OxshUIh63YAAJ6cc+ro6FA0GlVGxvnPc/pdAB0+fFglJSXWbQAALtPBgwdVXFx83sf73Z/gsrOzrVsAAKTAxZ7P0xZA1dXVuu666zRs2DCVlZXpgw8+uKQ6/uwGAIPDxZ7P0xJAb7zxhiorK7V69Wp9+OGHmjJliubNm6ejR4+mY3MAgIHIpcH06dNdRUVF4nZPT4+LRqOuqqrqorWxWMxJYmFhYWEZ4EssFrvg833Kz4C6u7tVX1+v8vLyxH0ZGRkqLy9XXV3dWet3dXUpHo8nLQCAwS/lAXTs2DH19PSooKAg6f6CggK1tLSctX5VVZUikUhi4Qo4ALgymF8Ft2rVKsViscRy8OBB65YAAH0g5e8DysvLU2ZmplpbW5Pub21tVWFh4Vnrh8NhhcPhVLcBAOjnUn4GlJWVpalTp6qmpiZxX29vr2pqajRjxoxUbw4AMEClZRJCZWWlli5dqm9+85uaPn26XnzxRXV2dup73/teOjYHABiA0hJA9957r/773//qmWeeUUtLi77xjW9oy5YtZ12YAAC4coWcc866iS+Lx+OKRCLWbQAALlMsFlNOTs55Hze/Cg4AcGUigAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYGKIdQMYuEKhUJ/UOOf6pAaXJ8j/bUaG/+/AfXUMSVJPT0+gOlwazoAAACYIIACAiZQH0LPPPqtQKJS0TJgwIdWbAQAMcGl5Dejmm2/W1q1b/38jQ3ipCQCQLC3JMGTIEBUWFqbjWwMABom0vAa0f/9+RaNRjR07Vg888IAOHDhw3nW7uroUj8eTFgDA4JfyACorK9P69eu1ZcsWrV27Vs3NzbrtttvU0dFxzvWrqqoUiUQSS0lJSapbAgD0QyGX5jdMtLe3a8yYMXrhhRf04IMPnvV4V1eXurq6Erfj8TghNEDwPiB8Ge8DwlfFYjHl5OSc9/G0Xx0wcuRI3XjjjWpsbDzn4+FwWOFwON1tAAD6mbS/D+j48eNqampSUVFRujcFABhAUh5Ajz/+uGpra/Xpp5/q73//u+6++25lZmbqvvvuS/WmAAADWMr/BHfo0CHdd999amtr07XXXqtbb71VO3fu1LXXXpvqTQEABrC0X4TgKx6PKxKJWLeRckFeOA0iyH9nkBeCgwrypuTe3t40dJI6ffVCdV/+qAY5JoYNG+ZdM2rUKO+azs5O75rzXYV7MadOnQpUhzMudhECs+AAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggAC
AJgggAAAJgggAIAJAggAYIIAAgCYSPsH0uGMfjbzNSWysrK8a4JMRQ+ynezsbO8aSfrf//7XJzXd3d3eNUEEHZQa5EMib7nlFu+au+++27vm97//vXdNkAGmSD/OgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJpiGjcCTukOhkHfN6NGjvWsWLVrkXTNu3DjvGknav3+/d83mzZu9a44dO+ZdE2SyddCp27fffrt3TWVlpXfNmDFjvGu2bt3qXXPw4EHvGinYMT4YJ9+nC2dAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATDCMFMrICPZ7SGlpqXfNz372M++aSZMmeddkZWV510jSTTfd5F0zZIj/j1GQYaS7d+/2runo6PCukaSFCxd610yePNm7Jsiwz/Lycu+a9957z7tGYrBounEGBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwATDSAMIMkCxr4YaBuktyDBNSVqwYIF3zZQpU7xrhg0b5l1z4MAB7xpJ+stf/uJds3PnTu+atrY275rGxkbvmkgk4l0jSVdddZV3TdChtr5yc3O9a06fPp2GTnC5OAMCAJgggAAAJrwDaMeOHVq4cKGi0ahCoZA2bdqU9LhzTs8884yKioo0fPhwlZeXa//+/anqFwAwSHgHUGdnp6ZMmaLq6upzPr5mzRq99NJLevnll7Vr1y6NGDFC8+bN08mTJy+7WQDA4OH96vOCBQvO++Kzc04vvviinnrqKd11112SpFdeeUUFBQXatGmTlixZcnndAgAGjZS+BtTc3KyWlpakj8yNRCIqKytTXV3dOWu6uroUj8eTFgDA4JfSAGppaZEkFRQUJN1fUFCQeOyrqqqqFIlEEktJSUkqWwIA9FPmV8GtWrVKsVgssRw8eNC6JQBAH0hpABUWFkqSWltbk+5vbW1NPPZV4XBYOTk5SQsAYPBLaQCVlpaqsLBQNTU1ifvi8bh27dqlGTNmpHJTAIABzvsquOPHjyeNBGlubtaePXuUm5ur0aNHa+XKlfrpT3+qG264QaWlpXr66acVjUa1aNGiVPYNABjgvANo9+7duuOOOxK3KysrJUlLly7V+vXr9cQTT6izs1MPP/yw2tvbdeutt2rLli2B5nkBAAYv7wCaPXv2BQdrhkIhPf/883r++ecvqzEEE2QYaXFxcaBtPf744941QX4R6ezs9K7ZuHGjd40k/eEPf/Cu+eyzz7xrggynDVLT29vrXSNJI0aM8K4JcuwF6S8cDnvX9NUwYPgxvwoOAHBlIoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCY8J6Gjf49WTczM9O75rvf/W6gbY0cOdK7Jsj04y9/wOGl+u1vf+tdI0mHDx/2rgk6cdpXkGnT0Wg00LZuuOEG75ogPxfd3d3eNR9++KF3TZCfC0nq6ekJVIdLwxkQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwwjDSDIUMi+GmAaiUS8a+68885A2woy4DHIfjh06JB3TdABoRkZffM7WZDtFBUVedf8/Oc/966RpNzcXO+aID8XQYaRFhcXe9dcd9113jWS1NjY6F3TV8NpBwPOgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgGGkAfTWMNMh2ggwjDTooNUh/Q4b4H3JLlizxrrnlllu8aySpvr7eu6anp8e75pprrvGuKSsr864ZP368d40kZWVledcEGcJ5+vRp75rZs2d717S1tXnX
SNK6deu8az777DPvmr4aVtzfcAYEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABMNIA+irwYFBhn12dHR417z11lveNZI0efJk75oRI0Z414waNcq7ZubMmd41knTrrbd61/TVcNrMzEzvmiC9ScGGhGZk9M3vsx9//LF3TWtra6BtBRmwikvHGRAAwAQBBAAw4R1AO3bs0MKFCxWNRhUKhbRp06akx5ctW6ZQKJS0zJ8/P1X9AgAGCe8A6uzs1JQpU1RdXX3edebPn68jR44kltdee+2ymgQADD7eFyEsWLBACxYsuOA64XBYhYWFgZsCAAx+aXkNaPv27crPz9f48eO1fPnyC34cbldXl+LxeNICABj8Uh5A8+fP1yuvvKKamhr98pe/VG1trRYsWKCenp5zrl9VVaVIJJJYSkpKUt0SAKAfSvn7gJYsWZL4etKkSZo8ebLGjRun7du3a86cOWetv2rVKlVWViZux+NxQggArgBpvwx77NixysvLU2Nj4zkfD4fDysnJSVoAAINf2gPo0KFDamtrU1FRUbo3BQAYQLz/BHf8+PGks5nm5mbt2bNHubm5ys3N1XPPPafFixersLBQTU1NeuKJJ3T99ddr3rx5KW0cADCweQfQ7t27dccddyRuf/H6zdKlS7V27Vrt3btXf/rTn9Te3q5oNKq5c+fqJz/5icLhcOq6BgAMeCHXV5M1L1E8HlckErFuY8DKysryrsnNzQ20rby8PO+ai72H7FxKS0u9a6ZNm+ZdI0k33nijd02QH6GhQ4d61wT5Je7zzz/3rpHO/KWjL7a1bds275pXX33Vu6apqcm7Rgo2xLSrqyvQtgajWCx2wdf1mQUHADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADDBNGwEFgqF+qQmMzPTuyaoq6++2rvm61//unfN6tWrvWsmTpzoXfP2229710jSpk2bvGuKi4u9az755BPvmiCTrdvb271rJOnkyZPeNf3sKdUU07ABAP0SAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAE0OsG8DAFWToYpCa3t5e75qgggytbG1t9a6Jx+PeNUEGd77yyiveNZL08ccfe9eMGjXKuybIsM8g/0fd3d3eNRKDRdONMyAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmGEYKXKaZM2d614wfP9675oMPPvCu+fTTT71rJOn06dPeNeFw2LsmFAp51xw7dsy7hqGi/RNnQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwwjBT4kqFDh3rXfPvb3/au6ezs9K7Zt2+fd013d7d3TVBDhvg/nQQZehpkgCn6J86AAAAmCCAAgAmvAKqqqtK0adOUnZ2t/Px8LVq0SA0NDUnrnDx5UhUVFRo1apSuvvpqLV68WK2trSltGgAw8HkFUG1trSoqKrRz5069++67OnXqlObOnZv09+zHHntMb7/9tt58803V1tbq8OHDuueee1LeOABgYPN61XDLli1Jt9evX6/8/HzV19dr1qxZisVi+uMf/6gNGzbozjvvlCStW7dON910k3bu3KlvfetbqescADCgXdZrQLFYTJKUm5srSaqvr9epU6dUXl6eWGfChAkaPXq06urqzvk9urq6FI/HkxYAwOAXOIB6e3u1cuVKzZw5UxMnTpQktbS0KCsrSyNHjkxat6CgQC0tLef8PlVVVYpEIomlpKQkaEsAgAEkcABVVFRo3759ev311y+rgVWrVikWiyWWgwcPXtb3AwAMDIHeiLpixQq988472rFjh4qLixP3
FxYWqru7W+3t7UlnQa2trSosLDzn9wqHwwqHw0HaAAAMYF5nQM45rVixQhs3btS2bdtUWlqa9PjUqVM1dOhQ1dTUJO5raGjQgQMHNGPGjNR0DAAYFLzOgCoqKrRhwwZt3rxZ2dnZidd1IpGIhg8frkgkogcffFCVlZXKzc1VTk6OHn30Uc2YMYMr4AAASbwCaO3atZKk2bNnJ92/bt06LVu2TJL061//WhkZGVq8eLG6uro0b948/e53v0tJswCAwcMrgJxzF11n2LBhqq6uVnV1deCmACvZ2dneNZfyc/FVHR0d3jVNTU19sh0p2L8pyCDXIINFe3t7vWvQPzELDgBgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgItAnogL9XUZGsN+tioqKvGuCTGc+3ycEX0gkEvGuycnJ8a6RpCFD/J8aCgoKvGv+/e9/e9dkZmZ61zBBu3/iDAgAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJhpFiUAo6jDTIoMsgNV1dXd41t912m3fNP/7xD+8aSTp27Jh3zYkTJ7xrggwJdc5514RCIe+avtxWkO0MBpwBAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMMEwUgxKPT09gepaWlq8a4IM1Ozo6PCuqa2t9a5pa2vzrpGkeDzuXdPc3OxdE4vFvGuCDhZF/8MZEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMh55yzbuLL4vG4IpGIdRu4Qg0bNsy75o477vCu6e7u9q75+OOPvWva29u9ayQpyNNCVlaWd02QobGnT5/2rgkyMFYKNvi0nz2lmorFYsrJyTnv45wBAQBMEEAAABNeAVRVVaVp06YpOztb+fn5WrRokRoaGpLWmT17tkKhUNLyyCOPpLRpAMDA5xVAtbW1qqio0M6dO/Xuu+/q1KlTmjt3rjo7O5PWe+ihh3TkyJHEsmbNmpQ2DQAY+Lw+EXXLli1Jt9evX6/8/HzV19dr1qxZifuvuuoqFRYWpqZDAMCgdFmvAX3xcbq5ublJ97/66qvKy8vTxIkTtWrVKp04ceK836Orq0vxeDxpAQAMfl5nQF/W29urlStXaubMmZo4cWLi/vvvv19jxoxRNBrV3r179eSTT6qhoUFvvfXWOb9PVVWVnnvuuaBtAAAGqMDvA1q+fLn++te/6v3331dxcfF519u2bZvmzJmjxsZGjRs37qzHu7q61NXVlbgdj8dVUlISpCXgsvE+oDN4H9AZvA/o8lzsfUCBzoBWrFihd955Rzt27Lhg+EhSWVmZJJ03gMLhsMLhcJA2AAADmFcAOef06KOPauPGjdq+fbtKS0svWrNnzx5JUlFRUaAGAQCDk1cAVVRUaMOGDdq8ebOys7PV0tIiSYpEIho+fLiampq0YcMGfec739GoUaO0d+9ePfbYY5o1a5YmT56cln8AAGBg8gqgtWvXSjrzZtMvW7dunZYtW6asrCxt3bpVL774ojo7O1VSUqLFixfrqaeeSlnDAIDBwftPcBdSUlKi2tray2oIAHBlCHwZNjAYBblaqq6uzrsmyNVfx48f967pyyuyTp486V2TkeH/VsSgV7QFwRVt6cUwUgCACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYYRgp8SZCPyg5SMxgFGdwZZCgrBg/OgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgot8FUJB5UgCA/udiz+f9LoA6OjqsWwAApMDFns9Drp+dcvT29urw4cPKzs5W
KBRKeiwej6ukpEQHDx5UTk6OUYf22A9nsB/OYD+cwX44oz/sB+ecOjo6FI1GlZFx/vOcfvdxDBkZGSouLr7gOjk5OVf0AfYF9sMZ7Icz2A9nsB/OsN4PkUjkouv0uz/BAQCuDAQQAMDEgAqgcDis1atXKxwOW7diiv1wBvvhDPbDGeyHMwbSfuh3FyEAAK4MA+oMCAAweBBAAAATBBAAwAQBBAAwMWACqLq6Wtddd52GDRumsrIyffDBB9Yt9blnn31WoVAoaZkwYYJ1W2m3Y8cOLVy4UNFoVKFQSJs2bUp63DmnZ555RkVFRRo+fLjKy8u1f/9+m2bT6GL7YdmyZWcdH/Pnz7dpNk2qqqo0bdo0ZWdnKz8/X4sWLVJDQ0PSOidPnlRFRYVGjRqlq6++WosXL1Zra6tRx+lxKfth9uzZZx0PjzzyiFHH5zYgAuiNN95QZWWlVq9erQ8//FBTpkzRvHnzdPToUevW+tzNN9+sI0eOJJb333/fuqW06+zs1JQpU1RdXX3Ox9esWaOXXnpJL7/8snbt2qURI0Zo3rx5OnnyZB93ml4X2w+SNH/+/KTj47XXXuvDDtOvtrZWFRUV2rlzp959912dOnVKc+fOVWdnZ2Kdxx57TG+//bbefPNN1dbW6vDhw7rnnnsMu069S9kPkvTQQw8lHQ9r1qwx6vg83AAwffp0V1FRkbjd09PjotGoq6qqMuyq761evdpNmTLFug1TktzGjRsTt3t7e11hYaH71a9+lbivvb3dhcNh99prrxl02De+uh+cc27p0qXurrvuMunHytGjR50kV1tb65w7838/dOhQ9+abbybW+eSTT5wkV1dXZ9Vm2n11Pzjn3O233+5+8IMf2DV1Cfr9GVB3d7fq6+tVXl6euC8jI0Pl5eWqq6sz7MzG/v37FY1GNXbsWD3wwAM6cOCAdUummpub1dLSknR8RCIRlZWVXZHHx/bt25Wfn6/x48dr+fLlamtrs24prWKxmCQpNzdXklRfX69Tp04lHQ8TJkzQ6NGjB/Xx8NX98IVXX31VeXl5mjhxolatWqUTJ05YtHde/W4Y6VcdO3ZMPT09KigoSLq/oKBA//rXv4y6slFWVqb169dr/PjxOnLkiJ577jnddttt2rdvn7Kzs63bM9HS0iJJ5zw+vnjsSjF//nzdc889Ki0tVVNTk3784x9rwYIFqqurU2ZmpnV7Kdfb26uVK1dq5syZmjhxoqQzx0NWVpZGjhyZtO5gPh7OtR8k6f7779eYMWMUjUa1d+9ePfnkk2poaNBbb71l2G2yfh9A+H8LFixIfD158mSVlZVpzJgx+vOf/6wHH3zQsDP0B0uWLEl8PWnSJE2ePFnjxo3T9u3bNWfOHMPO0qOiokL79u27Il4HvZDz7YeHH3448fWkSZNUVFSkOXPmqKmpSePGjevrNs+p3/8JLi8vT5mZmWddxdLa2qrCwkKjrvqHkSNH6sYbb1RjY6N1K2a+OAY4Ps42duxY5eXlDcrjY8WKFXrnnXf03nvvJX18S2Fhobq7u9Xe3p60/mA9Hs63H86lrKxMkvrV8dDvAygrK0tTp05VTU1N4r7e3l7V1NRoxowZhp3ZO378uJqamlRUVGTdipnS0lIVFhYmHR/xeFy7du264o+PQ4cOqa2tbVAdH845rVixQhs3btS2bdtUWlqa9PjUqVM1dOjQpOOhoaFBBw4cGFTHw8X2w7ns2bNHkvrX8WB9FcSleP311104HHbr1693//znP93DDz/sRo4c6VpaWqxb61M//OEP3fbt211zc7P729/+5srLy11eXp47evSodWtp1dHR4T766CP30UcfOUnuhRdecB999JH7z3/+45xz7he/+IUbOXKk27x5s9u7d6+76667XGlpqfv888+NO0+tC+2Hjo4O9/jjj7u6ujrX3Nzstm7d6m655RZ3ww03uJMnT1q3njLLly93kUjEbd++3R05ciSxnDhxIrHOI4884kaPHu22
bdvmdu/e7WbMmOFmzJhh2HXqXWw/NDY2uueff97t3r3bNTc3u82bN7uxY8e6WbNmGXeebEAEkHPO/eY3v3GjR492WVlZbvr06W7nzp3WLfW5e++91xUVFbmsrCz3ta99zd17772usbHRuq20e++995yks5alS5c6585civ3000+7goICFw6H3Zw5c1xDQ4Nt02lwof1w4sQJN3fuXHfttde6oUOHujFjxriHHnpo0P2Sdq5/vyS3bt26xDqff/65+/73v++uueYad9VVV7m7777bHTlyxK7pNLjYfjhw4ICbNWuWy83NdeFw2F1//fXuRz/6kYvFYraNfwUfxwAAMNHvXwMCAAxOBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATPwfhZD5A9Xr+F0AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Decode one latent code drawn from a standard normal to inspect a generated image.\n",
    "code = torch.randn(1, 20)\n",
    "if torch.cuda.is_available():\n",
    "    # The notebook moves the model to the GPU whenever CUDA is available,\n",
    "    # so the latent code has to be moved there as well.\n",
    "    code = code.cuda()\n",
    "with torch.no_grad():  # inference only; no autograd graph is needed\n",
    "    decode = net.decode(code)\n",
    "# Bring the result back to the CPU before converting to NumPy (required for\n",
    "# CUDA tensors), then scale from [0, 1] to [0, 255] for display.\n",
    "decode_img = to_img(decode).squeeze().cpu().numpy() * 255\n",
    "plt.imshow(decode_img.astype('uint8'), cmap='gray')  # show the generated image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64c7491a-95d6-4eff-b7d0-c8aa58b2fd14",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "none",
   "dataSources": [],
   "dockerImageVersionId": 30733,
   "isGpuEnabled": false,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3 (ipykernel)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 1491.585132,
   "end_time": "2024-06-27T03:10:11.391316",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2024-06-27T02:45:19.806184",
   "version": "2.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
